import cv2
import numpy as np
from matplotlib import pyplot as plt


def batch_parse(batch_img, batch_pre, all_mask=False, plot=False, return_cnt=False):
    """Run ``parse_mask`` on every (image, prediction) pair of a batch.

    Pairs are formed with ``zip``, so iteration stops at the shorter batch.

    :param batch_img: batch of images, shape (batch_size, H, W, channels).
    :param batch_pre: batch of predicted masks, shape (batch_size, H, W, 1).
    :param all_mask: forwarded unchanged to ``parse_mask``.
    :param plot: forwarded unchanged to ``parse_mask``.
    :param return_cnt: forwarded unchanged to ``parse_mask``.
    :return: generator yielding, in batch order, every item that
        ``parse_mask`` yields for each pair.
    """
    for image, prediction in zip(batch_img, batch_pre):
        yield from parse_mask(image, prediction, all_mask, plot, return_cnt)


def parse_mask(img, pre, all_mask=False, plot=False, return_cnt=False):
    """Cut the regions found in a predicted mask out of an image.

    The mask is eroded once, thresholded at 127 and split into contours with
    ``cv2.findContours`` (``RETR_LIST`` / ``CHAIN_APPROX_SIMPLE``). Each filled
    contour is dilated 6 times with a 3x3 cross kernel and multiplied into
    ``img``.

    :param img: ndarray, expected (H, W, 3). If its maximum is <= 1 it is
        treated as a [0, 1] image and rescaled to uint8 0-255. The caller's
        array is never modified.
    :param pre: ndarray, (H, W) or (H, W, C); only channel 0 is used for 3-D
        input. If its maximum is <= 1 it is rescaled to 0-255. The caller's
        array is never modified.
    :param all_mask: bool, default False. If True, yield one image masked by
        all contours together; otherwise yield one masked image per contour.
    :param plot: bool, default False. If True, call ``plt.imshow`` on every
        yielded image (useful in Jupyter or while debugging).
    :param return_cnt: bool, default False. If True, yield
        ``(masked_img, contours)`` instead. In all-mask mode ``contours`` is
        the whole sequence returned by ``cv2.findContours``. In per-contour
        mode it is ``np.array([cnt])`` for the single contour.
    :return: generator of masked images (floating point, shape (H, W, 3) for a
        3-channel ``img``). When ``return_cnt`` is True it yields
        ``(masked_img, contours)`` tuples.
    """
    # Rebind instead of scaling in place: ``pre *= 255`` / ``img *= 255``
    # silently modified the caller's arrays (including the caller's batch
    # when called through batch_parse) and failed outright for bool masks.
    pre = np.asarray(pre)
    img = np.asarray(img)
    if np.max(pre) <= 1:
        pre = pre * 255
    if np.max(img) <= 1:
        img = (img * 255).astype(np.uint8)
    if pre.ndim == 3:
        pre = pre[:, :, 0]
    # findContours only accepts 8-bit single-channel images. Normalise the
    # mask dtype even when it already arrived in the 0-255 range (e.g. as
    # float), which previously crashed inside findContours.
    pre = np.clip(pre, 0, 255).astype(np.uint8)

    kernel = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], np.uint8)  # 3x3 cross
    # Shrink the mask once before extracting contours (presumably to detach
    # regions that only touch); the dilation applied later grows them back.
    erosion = cv2.erode(pre, kernel, iterations=1)

    # Binarise the eroded mask, then extract its contours.
    _, thresh = cv2.threshold(erosion, 127, 255, cv2.THRESH_BINARY)
    # OpenCV 3 returns (image, contours, hierarchy); OpenCV 2/4 return
    # (contours, hierarchy). Index -2 is the contour list in every version.
    contours = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]

    if all_mask:
        result = _dilated_masked_image(img, contours, pre.shape, kernel)
        if plot:
            plt.imshow(result / 255)
        yield (result, contours) if return_cnt else result
    else:
        # Fill and dilate each contour on its own canvas, one result per region.
        for cnt in contours:
            result = _dilated_masked_image(img, [cnt], pre.shape, kernel)
            if plot:
                plt.imshow(result / 255)
            yield (result, np.array([cnt])) if return_cnt else result


def _dilated_masked_image(img, contours, shape, kernel, iterations=6):
    """Return ``img`` multiplied by the filled, dilated mask of ``contours``.

    The contours are filled with 1s on a float zero canvas of ``shape``. The
    canvas is dilated ``iterations`` times with ``kernel`` and replicated to
    three channels before the multiplication.
    """
    canvas = cv2.drawContours(np.zeros(shape), contours, -1, (1, 1, 1), -1)
    dilation = cv2.dilate(canvas, kernel, iterations=iterations)
    return img * np.stack([dilation, dilation, dilation], axis=-1)


def cut_rect(result, contour, offset=5):
    """Crop ``result`` to the bounding box of ``contour``, padded by ``offset``.

    :param result: the whole image, indexed as ``result[y, x, ...]``.
    :param contour: contour points accepted by ``cv2.boundingRect``, ndim 3 or
        4. A 4-D array (e.g. the ``np.array([cnt])`` that parse_mask yields
        with ``return_cnt=True``) is unwrapped to its first contour.
    :param offset: padding in pixels added on every side of the box.
    :return: the slice ``result[y-offset:y+h+offset, x-offset:x+w+offset]``
        of ``result``, with the start indices clamped to 0.
    """
    if contour.ndim == 4:
        contour = contour[0]

    x, y, w, h = cv2.boundingRect(contour)
    # Clamp the start indices. A negative start wraps around to the end of
    # the axis in NumPy slicing and returns a wrong (often empty) crop for
    # boxes near the top/left border. Stops past the far edge are already
    # truncated by slicing.
    top = max(y - offset, 0)
    left = max(x - offset, 0)
    return result[top:y + h + offset, left:x + w + offset]

